import os, sys
sys.path.append(os.getcwd())

from src.dataset import create_dataset

if __name__ == "__main__":
    # Smoke run: build the dataset from the files under "output/" and report
    # the size returned by ds.get_dataset_size().
    ds = create_dataset(
        8,  # presumably the batch size -- confirm against src.dataset.create_dataset
        data_path="output/",
        data_start_index=1,  # data_start_index=0 for dev_set
        eod_reset=1,
        full_batch=True,
        eod_id=50256,  # NOTE(review): looks like GPT-2's end-of-text token id -- confirm tokenizer
        device_num=1,
        rank=0,
        column_name="input_ids",
        epoch=1,
    )
    # NOTE(review): named as steps per epoch; assumes get_dataset_size()
    # counts batches rather than samples -- confirm for this pipeline.
    step_per_epoch = ds.get_dataset_size()
    # Emit the computed value; a bare print() would discard it and output
    # only an empty line.
    print(f"step_per_epoch: {step_per_epoch}")